{
 "metadata": {
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.9-final"
  },
  "orig_nbformat": 2,
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.9 64-bit",
   "metadata": {
    "interpreter": {
     "hash": "08bf4b89015df3fa62ec3e5cebfe3c3e5a181176f7e37ebd57af46c3a62ed99e"
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2,
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "AlexNet(\n  (conv): Sequential(\n    (0): Conv2d(1, 96, kernel_size=(11, 11), stride=(4, 4))\n    (1): ReLU()\n    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n    (3): Conv2d(96, 256, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))\n    (4): ReLU()\n    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n    (6): Conv2d(256, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (7): ReLU()\n    (8): Conv2d(384, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (9): ReLU()\n    (10): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))\n    (11): ReLU()\n    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)\n  )\n  (fc): Sequential(\n    (0): Linear(in_features=6400, out_features=4096, bias=True)\n    (1): ReLU()\n    (2): Dropout(p=0.5, inplace=False)\n    (3): Linear(in_features=4096, out_features=4096, bias=True)\n    (4): ReLU()\n    (5): Dropout(p=0.5, inplace=False)\n    (6): Linear(in_features=4096, out_features=10, bias=True)\n  )\n)\n"
     ]
    }
   ],
   "source": [
    "import time\n",
    "import torch \n",
    "from torch import nn, optim\n",
    "import torchvision\n",
    "import d2lzh_pytorch as d2l\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "\n",
    "class AlexNet(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(AlexNet, self).__init__()\n",
    "        self.conv = nn.Sequential(\n",
    "            nn.Conv2d(1, 96, 11, 4),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(3, 2),\n",
    "            nn.Conv2d(96, 256, 5, 1, 2),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(3, 2),\n",
    "            nn.Conv2d(256, 384, 3, 1, 1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(384, 384, 3, 1, 1),\n",
    "            nn.ReLU(),\n",
    "            nn.Conv2d(384, 256, 3, 1, 1),\n",
    "            nn.ReLU(),\n",
    "            nn.MaxPool2d(3, 2)\n",
    "        )\n",
    "\n",
    "        self.fc = nn.Sequential(\n",
    "            nn.Linear(256*5*5, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(4096, 4096),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(0.5),\n",
    "            nn.Linear(4096, 10)\n",
    "        )\n",
    "    \n",
    "    def forward(self, img):\n",
    "        feature = self.conv(img)\n",
    "        output = self.fc(feature.view(img.shape[0], -1))\n",
    "        return output\n",
    "\n",
    "net = AlexNet()\n",
    "print(net)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "training on  cuda\n",
      "epoch 1, loss 0.6751, train acc 0.740, test acc 0.841, time 131.1 sec\n",
      "epoch 2, loss 0.3590, train acc 0.866, test acc 0.865, time 125.4 sec\n",
      "epoch 3, loss 0.3085, train acc 0.884, test acc 0.889, time 125.5 sec\n",
      "epoch 4, loss 0.2744, train acc 0.899, test acc 0.893, time 125.3 sec\n",
      "epoch 5, loss 0.2550, train acc 0.906, test acc 0.889, time 125.2 sec\n"
     ]
    }
   ],
   "source": [
    "batch_size = 128\n",
    "train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)\n",
    "\n",
    "lr, num_epochs = 0.001, 5\n",
    "optimizer = torch.optim.Adam(net.parameters(), lr=lr)\n",
    "d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ]
}